import os
import subprocess
import pybedtools
from Bio import Entrez

# Taxonomic name of the organism to export. It is passed to the writer
# functions at the bottom of this file, which compare it against each record's
# Org-ref_taxname and derive the default input file name from it.
organism = "Homo sapiens"

class Transcript:
    """Plain record describing one transcript of a gene.

    Every field starts out unset and is filled in by the caller:

    * ``accession``       -- transcript accession, initially ``None``
    * ``chromosome``      -- chromosome name, initially ``None``
    * ``strand``          -- strand symbol, initially ``None``
    * ``exons``           -- list of exons, initially empty (a new list per
      instance)
    * ``transcript_type`` -- transcript type label, initially ``None``
    """

    def __init__(self):
        self.accession = self.chromosome = self.strand = None
        # Each instance gets its own list so exons are never shared.
        self.exons = list()
        self.transcript_type = None

class Transcripts(list):
    """List of Transcript objects belonging to one gene, plus gene metadata.

    The list items are the transcripts; the gene-level information is kept in
    attributes on the list object itself:

    * ``gene_id``     -- gene identifier, initially ``None``
    * ``gene_name``   -- gene name, initially ``None``
    * ``gene_type``   -- gene type label, initially ``None``
    * ``description`` -- coarse classification of the gene, initially ``None``
    """

    def __init__(self):
        super().__init__()
        self.gene_id = None
        self.gene_name = None
        self.gene_type = None
        # Declared here so that consumers reading ``description`` never hit an
        # AttributeError on an instance that has not been classified yet.
        self.description = None

def find_transcript_locations(record, organism, mirnas=None, primirnas=None):
    """Collect the transcripts of one gene from a parsed Entrez Gene record.

    Besides building the result, the function prints the gene name of every
    record that passes the organism and status checks.

    Args:
        record: one gene record, accessed as nested mappings/lists whose leaf
            values also expose an ``.attributes`` mapping (the writer
            functions in this file obtain records from ``Entrez.parse``).
        organism: taxonomic name that the record's ``Org-ref_taxname`` must
            equal for the record to be processed.
        mirnas: optional list; the gene name is appended to it when the gene
            description starts with "microRNA".
        primirnas: optional list; the gene name is appended to it for genes
            whose name starts with "MIR" and ends with "HG" (host genes).

    Returns:
        ``None`` when the record belongs to another organism, has track
        status 'discontinued' or 'secondary', or has gene type tRNA, rRNA or
        biological-region. Otherwise a ``Transcripts`` list with ``gene_id``,
        ``gene_name``, ``gene_type`` and ``description`` set. The list is
        empty when the record is flagged 'not annotated on reference
        assembly' or when no locus entry passes the filters below.

    Raises:
        AssertionError: when one of the many consistency checks on the record
            fails.
        Exception: for an unknown gene type, an unexpected microRNA gene
            name, an unexpected gene type for a host gene, or an unknown
            strand value.
    """
    # Accession strings (as they appear in 'Gene-commentary_accession') mapped
    # to chr-prefixed chromosome names.
    # NOTE(review): 'NC_012920' does not start with "NC_00", so the chrM entry
    # looks unreachable given the accession prefix filter further down - confirm.
    chromosomes = {'NC_000001': 'chr1',
                   'NC_000002': 'chr2',
                   'NC_000003': 'chr3',
                   'NC_000004': 'chr4',
                   'NC_000005': 'chr5',
                   'NC_000006': 'chr6',
                   'NC_000007': 'chr7',
                   'NC_000008': 'chr8',
                   'NC_000009': 'chr9',
                   'NC_000010': 'chr10',
                   'NC_000011': 'chr11',
                   'NC_000012': 'chr12',
                   'NC_000013': 'chr13',
                   'NC_000014': 'chr14',
                   'NC_000015': 'chr15',
                   'NC_000016': 'chr16',
                   'NC_000017': 'chr17',
                   'NC_000018': 'chr18',
                   'NC_000019': 'chr19',
                   'NC_000020': 'chr20',
                   'NC_000021': 'chr21',
                   'NC_000022': 'chr22',
                   'NC_000023': 'chrX',
                   'NC_000024': 'chrY',
                   'NC_012920': 'chrM',
                  }
    transcripts = Transcripts()
    # Ignore records of any other organism.
    if record['Entrezgene_source']['BioSource']['BioSource_org']['Org-ref']['Org-ref_taxname']!=organism:
        return
    gene_track = record['Entrezgene_track-info']['Gene-track']
    status = gene_track['Gene-track_status'].attributes['value']
    # Only 'live' records are processed; any status other than the three
    # named here trips the assertion.
    if status in ('discontinued', 'secondary'):
        return
    assert status=='live'
    gene_type = record['Entrezgene_type'].attributes['value']
    gene_ref = record['Entrezgene_gene']['Gene-ref']
    # Prefer the locus name; fall back to the first synonym when it is absent.
    gene_name = gene_ref.get('Gene-ref_locus')
    if not gene_name:
        gene_name = gene_ref['Gene-ref_syn'][0]
    print(gene_name)
    # Cross-check the textual gene type against the numeric code stored as the
    # element value. tRNA, rRNA and biological-region genes are skipped.
    if gene_type=='unknown':
        assert record['Entrezgene_type']=='0'
    elif gene_type=='tRNA':
        assert record['Entrezgene_type']=='1'
        return
    elif gene_type=='rRNA':
        assert record['Entrezgene_type']=='2'
        return
    elif gene_type=='snRNA':
        assert record['Entrezgene_type']=='3'
    elif gene_type=='scRNA':
        assert record['Entrezgene_type']=='4'
    elif gene_type=='snoRNA':
        assert record['Entrezgene_type']=='5'
    elif gene_type=='protein-coding':
        assert record['Entrezgene_type']=='6'
    elif gene_type=='pseudo':
        assert record['Entrezgene_type']=='7'
    elif gene_type=='miscRNA':
        assert record['Entrezgene_type']=='9'
    elif gene_type=='ncRNA':
        assert record['Entrezgene_type']=='10'
    elif gene_type=='biological-region':
        assert record['Entrezgene_type']=='11'
        return
    elif gene_type=='other':
        assert record['Entrezgene_type']=='255'
    else:
        comment = "Unknown gene type %s (%s) for gene %s" % (gene_type, record['Entrezgene_type'], gene_name)
        raise Exception(comment)
    transcripts.gene_id = gene_track['Gene-track_geneid']
    transcripts.gene_type = gene_type
    transcripts.gene_name = gene_name
    # Derive a coarse classification from the free-text gene description and
    # the gene name; the first matching rule wins.
    description = gene_ref.get('Gene-ref_desc', '')
    if 'ribosomal protein' in description:
        # Also include ribonucleoprotein?
        transcripts.description = "ribosomal-protein"
    elif 'histone family' in description or "histone cluster" in description:
        transcripts.description = "histone"
    elif description.startswith("microRNA"):
        # microRNA genes must be named MIRLET<digit>... or MIR<digit>...
        if gene_name.startswith("MIRLET"):
            assert gene_name[6] in '0123456789'
        elif gene_name.startswith("MIR"):
            assert gene_name[3] in '0123456789'
        else:
            raise Exception("microRNA with name %s" % gene_name)
        transcripts.description = 'microRNA'
        if mirnas is not None:
            mirnas.append(gene_name)
    elif gene_name.startswith("MIR"):
        # MIR-prefixed genes that are not described as microRNAs are expected
        # to be either antisense RNAs (-AS1) or host genes (HG).
        if gene_name.endswith("-AS1"):
            assert 'antisense RNA 1' in description
            transcripts.description = 'antisense'
        else:
            assert gene_name.endswith("HG")
            assert description.endswith(" host gene (non-protein coding)") or description.endswith(" host gene")
            if transcripts.gene_type not in ('ncRNA', 'miscRNA', 'protein-coding'):
                # MIR205HG has gene type "protein coding" though the description is "MIR205 host gene (non-protein coding)"
                raise Exception("Unexpected gene type for primicroRNA %s (%s)" % (gene_name, transcripts.gene_type))
            transcripts.description = 'primicroRNA'
            if primirnas is not None:
                primirnas.append(gene_name)
    elif 'antisense' in description:
        transcripts.description = 'antisense'
    else:
        transcripts.description = 'none'
    # Genes flagged as not annotated on the reference assembly are returned
    # with their metadata but without any transcripts.
    annotated = True
    for comment in record['Entrezgene_comments']:
        if comment.get('Gene-commentary_heading')!='Annotation Information':
            continue
        for property in comment['Gene-commentary_properties']:
            if property['Gene-commentary_text']=='not annotated on reference assembly':
                annotated = False
    if not annotated:
        return transcripts
    # Walk the locus entries, keeping only those with a heading, an accession
    # starting with "NC_00", and the exact assembly heading checked below.
    for gene_commentary in record.get('Entrezgene_locus', []):
        if not "Gene-commentary_heading" in gene_commentary:
            continue
        if not gene_commentary['Gene-commentary_accession'].startswith("NC_00"):
            continue

        if gene_commentary['Gene-commentary_heading']!='Reference GRCh38.p13 Primary Assembly':
            continue
        chromosome = chromosomes[gene_commentary['Gene-commentary_accession']]
        assert chromosome[:3]=='chr'
        assert len(gene_commentary['Gene-commentary_seqs'])==1
        loc_int = gene_commentary['Gene-commentary_seqs'][0]['Seq-loc_int']
        interval = loc_int['Seq-interval']
        start = int(interval['Seq-interval_from'])
        # End is stored as 'to' + 1.
        # NOTE(review): presumably turns an inclusive end into an exclusive
        # one; the writers in this file later emit start+1 and end - confirm.
        end = int(interval['Seq-interval_to'])+1
        strand = interval['Seq-interval_strand']['Na-strand'].attributes['value']
        if strand=='plus':
            strand = '+'
        elif strand=='minus':
            strand = '-'
        else:
            raise Exception("Unknown strand %s" % strand)
        products = gene_commentary.get('Gene-commentary_products')
        if not products:
            # No products listed: record a single transcript covering the
            # whole locus interval, with a placeholder accession.
            transcript = Transcript()
            transcript.chromosome = chromosome
            transcript.strand = strand
            transcript.exons = [[start, end]]
            transcript.accession = 'unknown'
            transcript.transcript_type = transcripts.gene_type
            transcripts.append(transcript)
            continue
        for product in products:
            transcript = Transcript()
            # Default to the gene type; miscRNA and ncRNA products override it.
            transcript.transcript_type = transcripts.gene_type
            # Product types that 'continue' are skipped; each branch also
            # checks that the numeric code and its textual value agree.
            if product['Gene-commentary_type']=="1":
                assert product['Gene-commentary_type'].attributes['value']=='genomic'
                continue
            elif product['Gene-commentary_type']=="2":
                assert product['Gene-commentary_type'].attributes['value']=='pre-RNA'
            elif product['Gene-commentary_type']=="3":
                assert product['Gene-commentary_type'].attributes['value']=='mRNA'
            elif product['Gene-commentary_type']=="5":
                assert product['Gene-commentary_type'].attributes['value']=='tRNA'
                continue
            elif product['Gene-commentary_type']=="11":
                assert product['Gene-commentary_type'].attributes['value']=='biological-region'
                continue
            elif product['Gene-commentary_type']=="14":
                assert product['Gene-commentary_type'].attributes['value']=='miscRNA'
                transcript.transcript_type = 'miscRNA'
            elif product['Gene-commentary_type']=="22":
                assert product['Gene-commentary_type'].attributes['value']=='ncRNA'
                transcript.transcript_type = 'ncRNA'
            elif product['Gene-commentary_type']=="26":
                assert product['Gene-commentary_type'].attributes.get('value') in ("c-region", None)  # constant part of an antibody
                continue
            elif product['Gene-commentary_type']=="27":
                assert product['Gene-commentary_type'].attributes.get('value') in ("d-segment", None)  # d-segment of an antibody
                continue
            elif product['Gene-commentary_type']=="28":
                assert product['Gene-commentary_type'].attributes.get('value') in ("j-segment", None)  # j-segment of an antibody
                continue
            elif product['Gene-commentary_type']=="29":
                assert product['Gene-commentary_type'].attributes.get('value') in ("v-region", None)  # variable part of an antibody
                continue
            else:
                # Any remaining product type is printed and must be rRNA; it
                # is skipped as well.
                print(product['Gene-commentary_type'].attributes['value'])
                assert product['Gene-commentary_type'].attributes['value']=='rRNA'
                continue
            transcript.accession = product['Gene-commentary_accession']
            # The transcript is appended first and filled in afterwards; the
            # list holds a reference to the same object.
            transcripts.append(transcript)
            transcript.chromosome = chromosome
            transcript.strand = strand
            # Coordinates are either a mix of several intervals or a single
            # interval; each interval becomes one [start, end] exon and must
            # lie on the same strand as the locus.
            for coordinates in product['Gene-commentary_genomic-coords']:
                if 'Seq-loc_mix' in coordinates:
                    seqlocs = coordinates['Seq-loc_mix']['Seq-loc-mix']
                    for seqloc in seqlocs:
                        interval = seqloc['Seq-loc_int']['Seq-interval']
                        start = int(interval['Seq-interval_from'])
                        end = int(interval['Seq-interval_to']) + 1
                        if strand=='+':
                            assert interval['Seq-interval_strand']['Na-strand'].attributes['value']=='plus'
                        elif strand=='-':
                            assert interval['Seq-interval_strand']['Na-strand'].attributes['value']=='minus'
                        transcript.exons.append([start, end])
                else:
                    interval = coordinates['Seq-loc_int']['Seq-interval']
                    start = int(interval['Seq-interval_from'])
                    end = int(interval['Seq-interval_to']) + 1
                    if strand=='+':
                        assert interval['Seq-interval_strand']['Na-strand'].attributes['value']=='plus'
                    elif strand=='-':
                        assert interval['Seq-interval_strand']['Na-strand'].attributes['value']=='minus'
                    transcript.exons.append([start, end])
            # Check that the exons are non-empty and do not overlap, then
            # store them in ascending coordinate order for '+' transcripts and
            # descending order for '-' transcripts.
            exons = sorted(transcript.exons)
            current = -1
            for exon in exons:
                assert exon[1] > exon[0]
                assert exon[0] >= current # some introns are zero-length
                current = exon[1]
            if strand=='+':
                transcript.exons = exons
            elif strand=='-':
                transcript.exons = exons[::-1]
    return transcripts


def timestamp(filename):
    """Return the modification date of *filename* as a "YYYY.MM.DD" string.

    The modification time is read with ``os.path.getmtime`` and formatted in
    local time.
    """
    import os
    import time
    modified = time.localtime(os.path.getmtime(filename))
    return time.strftime("%Y.%m.%d", modified)

def order(interval):
    """Return a sort key ``(chromosome rank, start, end, strand)``.

    Numbered chromosomes ("chr1", "chr2", ...) rank by their number, which
    must be below 98; "chrX", "chrY" and "chrM" rank 98, 99 and 100, so they
    sort after all numbered chromosomes.
    """
    named_ranks = {'chrX': 98, 'chrY': 99, 'chrM': 100}
    name = interval.chrom
    rank = named_ranks.get(name)
    if rank is None:
        # Strip the "chr" prefix and use the chromosome number itself.
        rank = int(name[3:])
        assert rank < 98
    return (rank, interval.start, interval.end, interval.strand)

def read_record(transcripts):
    """Yield one pybedtools interval per transcript in *transcripts*.

    *transcripts* is iterated for its transcript objects and must carry the
    gene-level attributes ``gene_id``, ``gene_name``, ``gene_type`` and
    ``description``. Each yielded interval is a 'transcript' feature running
    from the start of the first exon to the end of the last one; its
    attribute column lists the gene and transcript identifiers.
    """
    gene_id = transcripts.gene_id
    gene_name = transcripts.gene_name
    gene_type = transcripts.gene_type
    description = transcripts.description
    # One gene may appear at several locations (for instance on both chrX and
    # chrY), so every transcript is reported separately.
    for transcript in transcripts:
        accession = transcript.accession
        transcript_type = transcript.transcript_type
        chromosome = transcript.chromosome
        strand = transcript.strand
        # Minus-strand exons are stored in reverse; flip them so that the
        # walk below always runs in ascending coordinate order.
        if strand=='-':
            exons = transcript.exons[::-1]
        elif strand=='+':
            exons = transcript.exons
        start, end = exons[0]
        for later_exon in exons[1:]:
            assert later_exon[0] >= start
            assert later_exon[1] >= end
            end = later_exon[1]
        attribute_column = "; ".join([
            'GeneID=%s' % gene_id,
            'GeneName=%s' % gene_name,
            'Accession=%s' % accession,
            'GeneType=%s' % gene_type,
            'TranscriptType=%s' % transcript_type,
            'Description=%s' % description,
        ])
        fields = [chromosome, 'NCBI', 'transcript',
                  str(start+1), str(end),
                  ".", strand, ".",
                  attribute_column]
        yield pybedtools.create_interval_from_list(fields)

def write_transcript_file(organism, source=None):
    """Write "genes.gff" with one 'transcript' line per transcript.

    Runs ``gene2xml -b T -c T -i <source>``, parses its standard output with
    ``Entrez.parse`` and converts every record with
    ``find_transcript_locations``. Records for which that function returns
    ``None`` and genes of type snRNA, scRNA or snoRNA are skipped. The
    intervals are sorted with ``order`` and written after a
    ``##source-version`` header line carrying the source file name and its
    modification date.

    After writing, a message is printed for every collected host gene whose
    name minus its last two characters is not among the collected microRNA
    gene names.

    Args:
        organism: taxonomic name of the organism to export.
        source: path of the gene2xml input file; defaults to the organism
            name with spaces replaced by underscores plus ".ags.gz".
    """
    mirnas = []
    primirnas = []
    intervals = []
    if source is None:
        source = "%s.ags.gz" % organism.replace(" ", "_")
    command = ["gene2xml", "-b", "T", "-c", "T", "-i", source]
    print("Running %s" % " ".join(command))
    # The context manager closes the pipe and waits for the child process,
    # also when parsing raises, so no handle or zombie process is left behind.
    with subprocess.Popen(command, stdout=subprocess.PIPE) as process:
        for record in Entrez.parse(process.stdout):
            transcripts = find_transcript_locations(record, organism, mirnas, primirnas)
            if transcripts is None:
                continue
            if transcripts.gene_type in ('snRNA', 'scRNA', 'snoRNA'):
                continue
            # read_record builds exactly the per-transcript intervals needed
            # here, so reuse it instead of duplicating its logic.
            intervals.extend(read_record(transcripts))
    intervals.sort(key=order)
    filename = "genes.gff"
    print("Writing", filename)
    # Look up the date before opening the output so that a failure here does
    # not leave a truncated output file behind.
    source_date = timestamp(source)
    with open(filename, 'w') as handle:
        handle.write("##source-version NCBI-Entrez:%s %s\n" % (source, source_date))
        for interval in intervals:
            handle.write(str(interval))
    for primirna in primirnas:
        # Host gene names end in "HG"; dropping it gives the microRNA name.
        mirna = primirna[:-2]
        if mirna not in mirnas:
            print("primiRNA %s without a mirna" % primirna)

def write_exon_file(organism, source=None):
    """Write "exons.gff" with one line per exon of every transcript.

    Runs ``gene2xml -b T -c T -i <source>``, parses its standard output with
    ``Entrez.parse`` and converts every record with
    ``find_transcript_locations``; records for which that function returns
    ``None`` are skipped. The file starts with a ``##source-version`` header
    line carrying the source file name and its modification date.

    The third column of every line has the form
    ``counter:gene_id:gene_name:accession:position:n`` where ``counter``
    numbers the transcripts across the whole file, ``n`` is the number of
    exons of the transcript, and ``position`` is the list index of the exon
    for '+' transcripts and ``n-1-index`` for '-' transcripts.

    Args:
        organism: taxonomic name of the organism to export.
        source: path of the gene2xml input file; defaults to the organism
            name with spaces replaced by underscores plus ".ags.gz".

    Raises:
        ValueError: if a transcript has a strand other than '+' or '-'.
    """
    if source is None:
        source = "%s.ags.gz" % organism.replace(" ", "_")
    command = ["gene2xml", "-b", "T", "-c", "T", "-i", source]
    # Look up the date first so that a missing source file fails before the
    # output file is created or truncated.
    source_date = timestamp(source)
    print("Running %s" % " ".join(command))
    # Both context managers guarantee cleanup (pipe closed, child process
    # waited for, output flushed and closed) even when parsing raises.
    with subprocess.Popen(command, stdout=subprocess.PIPE) as process, \
            open("exons.gff", 'w') as output:
        output.write("##source-version NCBI-Entrez:%s %s\n" % (source, source_date))
        # A gene can be mapped to multiple locations.
        # Use a counter to distinguish all mappings
        counter = 0
        for record in Entrez.parse(process.stdout):
            transcripts = find_transcript_locations(record, organism)
            if transcripts is None:
                continue
            gene_id = transcripts.gene_id
            gene_name = transcripts.gene_name
            for transcript in transcripts:
                accession = transcript.accession
                transcript_type = transcript.transcript_type
                chromosome = transcript.chromosome
                strand = transcript.strand
                exons = transcript.exons
                n = len(exons)
                for i, (start, end) in enumerate(exons):
                    if strand=='+':
                        position = i
                    elif strand=='-':
                        position = n-1-i
                    else:
                        raise ValueError("Unknown strand %s" % strand)
                    line = "%s\tNCBI\t%d:%s:%s:%s:%d:%d\t%d\t%d\t.\t%s\t.\t%s\n" % (chromosome, counter, gene_id, gene_name, accession, position, n, start+1, end, strand, transcript_type)
                    output.write(line)
                counter += 1

def write_intron_file(organism, source=None):
    """Write "introns.gff.new" with one line per intron of every transcript.

    Runs ``gene2xml -b T -c T -i <source>``, parses its standard output with
    ``Entrez.parse`` and converts every record with
    ``find_transcript_locations``; records for which that function returns
    ``None`` are skipped. The file starts with a ``##source-version`` header
    line carrying the source file name and its modification date.

    An intron is the gap between two consecutive entries of a transcript's
    exon list. The third column of every line has the form
    ``counter:gene_id:gene_name:accession:position:n`` where ``counter``
    numbers the transcripts across the whole file, ``n`` is the number of
    exons of the transcript, and ``position`` is the index of the exon pair
    for '+' transcripts and ``n-1-index`` for '-' transcripts.

    Args:
        organism: taxonomic name of the organism to export.
        source: path of the gene2xml input file; defaults to the organism
            name with spaces replaced by underscores plus ".ags.gz".

    Raises:
        ValueError: if a transcript has a strand other than '+' or '-'.
    """
    if source is None:
        source = "%s.ags.gz" % organism.replace(" ", "_")
    command = ["gene2xml", "-b", "T", "-c", "T", "-i", source]
    # Look up the date first so that a missing source file fails before the
    # output file is created or truncated.
    source_date = timestamp(source)
    print("Running %s" % " ".join(command))
    # Both context managers guarantee cleanup (pipe closed, child process
    # waited for, output flushed and closed) even when parsing raises.
    with subprocess.Popen(command, stdout=subprocess.PIPE) as process, \
            open("introns.gff.new", 'w') as output:
        output.write("##source-version NCBI-Entrez:%s %s\n" % (source, source_date))
        # A gene can be mapped to multiple locations.
        # Use a counter to distinguish all mappings
        counter = 0
        for record in Entrez.parse(process.stdout):
            transcripts = find_transcript_locations(record, organism)
            # find_transcript_locations returns None for records it skips
            # (other organism, discontinued genes, tRNA/rRNA, ...).
            if transcripts is None:
                continue
            gene_id = transcripts.gene_id
            gene_name = transcripts.gene_name
            for transcript in transcripts:
                accession = transcript.accession
                transcript_type = transcript.transcript_type
                chromosome = transcript.chromosome
                strand = transcript.strand
                exons = transcript.exons
                n = len(exons)
                for i in range(n-1):
                    start1, end1 = exons[i]
                    start2, end2 = exons[i+1]
                    if strand=='+':
                        position = i
                        start = end1
                        end = start2
                    elif strand=='-':
                        # Minus-strand exons are stored in descending order,
                        # so the second exon of the pair lies upstream.
                        position = n-1-i
                        start = end2
                        end = start1
                    else:
                        raise ValueError("Unknown strand %s" % strand)
                    line = "%s\tNCBI\t%d:%s:%s:%s:%d:%d\t%d\t%d\t.\t%s\t.\t%s\n" % (chromosome, counter, gene_id, gene_name, accession, position, n, start+1, end, strand, transcript_type)
                    output.write(line)
                counter += 1

# Script entry point: these calls run whenever the file is executed or
# imported, since there is no __main__ guard. They produce the transcript and
# exon files for the organism configured at the top of the file; the intron
# export is currently disabled.
write_transcript_file(organism)
write_exon_file(organism)
# write_intron_file(organism)
